{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Classification of Diseased and Healthy 3D Coronary Artery Shapes Using MONAI (Minimal Reproducible Example)"
      ],
      "metadata": {
        "id": "Hp27R4K9tUKs"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "installation"
      ],
      "metadata": {
        "id": "Rw2hgMcrtwKp"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "xL4AU_X4rveB",
        "outputId": "bb919f82-2534-4272-813c-27f8221d964d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting monai\n",
            "  Downloading monai-1.3.0-202310121228-py3-none-any.whl (1.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.20 in /usr/local/lib/python3.10/dist-packages (from monai) (1.23.5)\n",
            "Requirement already satisfied: torch>=1.9 in /usr/local/lib/python3.10/dist-packages (from monai) (2.1.0+cu121)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.13.1)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (4.5.0)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (1.12)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.2.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (3.1.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (2023.6.0)\n",
            "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.9->monai) (2.1.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.9->monai) (2.1.3)\n",
            "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.9->monai) (1.3.0)\n",
            "Installing collected packages: monai\n",
            "Successfully installed monai-1.3.0\n",
            "Collecting MedShapeNetCore\n",
            "  Downloading MedShapeNetCore-0.1.0-py3-none-any.whl (6.1 kB)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (1.23.5)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (1.5.3)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (1.2.2)\n",
            "Requirement already satisfied: scikit-image in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (0.19.3)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (4.66.1)\n",
            "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (9.4.0)\n",
            "Collecting fire (from MedShapeNetCore)\n",
            "  Downloading fire-0.5.0.tar.gz (88 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.3/88.3 kB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting trimesh (from MedShapeNetCore)\n",
            "  Downloading trimesh-4.0.9-py3-none-any.whl (689 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m689.0/689.0 kB\u001b[0m \u001b[31m24.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting SimpleITK (from MedShapeNetCore)\n",
            "  Downloading SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m52.7/52.7 MB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting open3d (from MedShapeNetCore)\n",
            "  Downloading open3d-0.18.0-cp310-cp310-manylinux_2_27_x86_64.whl (399.7 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m399.7/399.7 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (1.11.4)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (3.7.1)\n",
            "Collecting clint (from MedShapeNetCore)\n",
            "  Downloading clint-0.5.1.tar.gz (29 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from MedShapeNetCore) (2.31.0)\n",
            "Collecting argparse (from MedShapeNetCore)\n",
            "  Downloading argparse-1.4.0-py2.py3-none-any.whl (23 kB)\n",
            "Collecting args (from clint->MedShapeNetCore)\n",
            "  Downloading args-0.1.0.tar.gz (3.0 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from fire->MedShapeNetCore) (1.16.0)\n",
            "Requirement already satisfied: termcolor in /usr/local/lib/python3.10/dist-packages (from fire->MedShapeNetCore) (2.4.0)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (1.2.0)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (4.47.0)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (1.4.5)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (23.2)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (3.1.1)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->MedShapeNetCore) (2.8.2)\n",
            "Collecting dash>=2.6.0 (from open3d->MedShapeNetCore)\n",
            "  Downloading dash-2.14.2-py3-none-any.whl (10.2 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.2/10.2 MB\u001b[0m \u001b[31m107.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: werkzeug>=2.2.3 in /usr/local/lib/python3.10/dist-packages (from open3d->MedShapeNetCore) (3.0.1)\n",
            "Requirement already satisfied: nbformat>=5.7.0 in /usr/local/lib/python3.10/dist-packages (from open3d->MedShapeNetCore) (5.9.2)\n",
            "Collecting configargparse (from open3d->MedShapeNetCore)\n",
            "  Downloading ConfigArgParse-1.7-py3-none-any.whl (25 kB)\n",
            "Collecting ipywidgets>=8.0.4 (from open3d->MedShapeNetCore)\n",
            "  Downloading ipywidgets-8.1.1-py3-none-any.whl (139 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m139.4/139.4 kB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting addict (from open3d->MedShapeNetCore)\n",
            "  Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
            "Requirement already satisfied: pyyaml>=5.4.1 in /usr/local/lib/python3.10/dist-packages (from open3d->MedShapeNetCore) (6.0.1)\n",
            "Collecting pyquaternion (from open3d->MedShapeNetCore)\n",
            "  Downloading pyquaternion-0.9.9-py3-none-any.whl (14 kB)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->MedShapeNetCore) (2023.3.post1)\n",
            "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->MedShapeNetCore) (1.3.2)\n",
            "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->MedShapeNetCore) (3.2.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->MedShapeNetCore) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->MedShapeNetCore) (3.6)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->MedShapeNetCore) (2.0.7)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->MedShapeNetCore) (2023.11.17)\n",
            "Requirement already satisfied: networkx>=2.2 in /usr/local/lib/python3.10/dist-packages (from scikit-image->MedShapeNetCore) (3.2.1)\n",
            "Requirement already satisfied: imageio>=2.4.1 in /usr/local/lib/python3.10/dist-packages (from scikit-image->MedShapeNetCore) (2.31.6)\n",
            "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.10/dist-packages (from scikit-image->MedShapeNetCore) (2023.12.9)\n",
            "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-image->MedShapeNetCore) (1.5.0)\n",
            "Requirement already satisfied: Flask<3.1,>=1.0.4 in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (2.2.5)\n",
            "Requirement already satisfied: plotly>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (5.15.0)\n",
            "Collecting dash-html-components==2.0.0 (from dash>=2.6.0->open3d->MedShapeNetCore)\n",
            "  Downloading dash_html_components-2.0.0-py3-none-any.whl (4.1 kB)\n",
            "Collecting dash-core-components==2.0.0 (from dash>=2.6.0->open3d->MedShapeNetCore)\n",
            "  Downloading dash_core_components-2.0.0-py3-none-any.whl (3.8 kB)\n",
            "Collecting dash-table==5.0.0 (from dash>=2.6.0->open3d->MedShapeNetCore)\n",
            "  Downloading dash_table-5.0.0-py3-none-any.whl (3.9 kB)\n",
            "Requirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (4.5.0)\n",
            "Collecting retrying (from dash>=2.6.0->open3d->MedShapeNetCore)\n",
            "  Downloading retrying-1.3.4-py3-none-any.whl (11 kB)\n",
            "Collecting ansi2html (from dash>=2.6.0->open3d->MedShapeNetCore)\n",
            "  Downloading ansi2html-1.9.1-py3-none-any.whl (17 kB)\n",
            "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (1.5.8)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (67.7.2)\n",
            "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.10/dist-packages (from dash>=2.6.0->open3d->MedShapeNetCore) (7.0.1)\n",
            "Collecting comm>=0.1.3 (from ipywidgets>=8.0.4->open3d->MedShapeNetCore)\n",
            "  Downloading comm-0.2.1-py3-none-any.whl (7.2 kB)\n",
            "Requirement already satisfied: ipython>=6.1.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets>=8.0.4->open3d->MedShapeNetCore) (7.34.0)\n",
            "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.10/dist-packages (from ipywidgets>=8.0.4->open3d->MedShapeNetCore) (5.7.1)\n",
            "Collecting widgetsnbextension~=4.0.9 (from ipywidgets>=8.0.4->open3d->MedShapeNetCore)\n",
            "  Downloading widgetsnbextension-4.0.9-py3-none-any.whl (2.3 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: jupyterlab-widgets~=3.0.9 in /usr/local/lib/python3.10/dist-packages (from ipywidgets>=8.0.4->open3d->MedShapeNetCore) (3.0.9)\n",
            "Requirement already satisfied: fastjsonschema in /usr/local/lib/python3.10/dist-packages (from nbformat>=5.7.0->open3d->MedShapeNetCore) (2.19.1)\n",
            "Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.10/dist-packages (from nbformat>=5.7.0->open3d->MedShapeNetCore) (4.19.2)\n",
            "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.10/dist-packages (from nbformat>=5.7.0->open3d->MedShapeNetCore) (5.7.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=2.2.3->open3d->MedShapeNetCore) (2.1.3)\n",
            "Requirement already satisfied: Jinja2>=3.0 in /usr/local/lib/python3.10/dist-packages (from Flask<3.1,>=1.0.4->dash>=2.6.0->open3d->MedShapeNetCore) (3.1.2)\n",
            "Requirement already satisfied: itsdangerous>=2.0 in /usr/local/lib/python3.10/dist-packages (from Flask<3.1,>=1.0.4->dash>=2.6.0->open3d->MedShapeNetCore) (2.1.2)\n",
            "Requirement already satisfied: click>=8.0 in /usr/local/lib/python3.10/dist-packages (from Flask<3.1,>=1.0.4->dash>=2.6.0->open3d->MedShapeNetCore) (8.1.7)\n",
            "Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore)\n",
            "  Downloading jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m54.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (4.4.2)\n",
            "Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.7.5)\n",
            "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (3.0.43)\n",
            "Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (2.16.1)\n",
            "Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.2.0)\n",
            "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.1.6)\n",
            "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (4.9.0)\n",
            "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.7.0->open3d->MedShapeNetCore) (23.2.0)\n",
            "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.7.0->open3d->MedShapeNetCore) (2023.12.1)\n",
            "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.7.0->open3d->MedShapeNetCore) (0.32.1)\n",
            "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat>=5.7.0->open3d->MedShapeNetCore) (0.16.2)\n",
            "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly>=5.0.0->dash>=2.6.0->open3d->MedShapeNetCore) (8.2.3)\n",
            "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata->dash>=2.6.0->open3d->MedShapeNetCore) (3.17.0)\n",
            "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.10/dist-packages (from jupyter-core->nbformat>=5.7.0->open3d->MedShapeNetCore) (4.1.0)\n",
            "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.8.3)\n",
            "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.7.0)\n",
            "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=6.1.0->ipywidgets>=8.0.4->open3d->MedShapeNetCore) (0.2.12)\n",
            "Building wheels for collected packages: clint, fire, args\n",
            "  Building wheel for clint (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for clint: filename=clint-0.5.1-py3-none-any.whl size=34457 sha256=e6e7ad3591d49b6ab17c707a360c7e8aa3927a70f4327370a3c470d4770e735c\n",
            "  Stored in directory: /root/.cache/pip/wheels/60/bd/a0/c20dd085251d95af656b5e3d287db7d4b4e8aec67b53a6f8dd\n",
            "  Building wheel for fire (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116934 sha256=5c648aee655a281ab4935cef80a9d681a215a371bd0c1957521dd6ba25c2d32c\n",
            "  Stored in directory: /root/.cache/pip/wheels/90/d4/f7/9404e5db0116bd4d43e5666eaa3e70ab53723e1e3ea40c9a95\n",
            "  Building wheel for args (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for args: filename=args-0.1.0-py3-none-any.whl size=3318 sha256=7b0d9b59777d68535ee2f78772eef2bc10703bbafb0d05f117c0a40aaaccde35\n",
            "  Stored in directory: /root/.cache/pip/wheels/18/d7/bc/7b88d8405d97070a1a62712fd639ea0ad8d14b3dd74075cca6\n",
            "Successfully built clint fire args\n",
            "Installing collected packages: SimpleITK, dash-table, dash-html-components, dash-core-components, args, argparse, addict, widgetsnbextension, trimesh, retrying, pyquaternion, jedi, fire, configargparse, comm, clint, ansi2html, ipywidgets, dash, open3d, MedShapeNetCore\n",
            "  Attempting uninstall: widgetsnbextension\n",
            "    Found existing installation: widgetsnbextension 3.6.6\n",
            "    Uninstalling widgetsnbextension-3.6.6:\n",
            "      Successfully uninstalled widgetsnbextension-3.6.6\n",
            "  Attempting uninstall: ipywidgets\n",
            "    Found existing installation: ipywidgets 7.7.1\n",
            "    Uninstalling ipywidgets-7.7.1:\n",
            "      Successfully uninstalled ipywidgets-7.7.1\n",
            "Successfully installed MedShapeNetCore-0.1.0 SimpleITK-2.3.1 addict-2.4.0 ansi2html-1.9.1 argparse-1.4.0 args-0.1.0 clint-0.5.1 comm-0.2.1 configargparse-1.7 dash-2.14.2 dash-core-components-2.0.0 dash-html-components-2.0.0 dash-table-5.0.0 fire-0.5.0 ipywidgets-8.1.1 jedi-0.19.1 open3d-0.18.0 pyquaternion-0.9.9 retrying-1.3.4 trimesh-4.0.9 widgetsnbextension-4.0.9\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "argparse"
                ]
              }
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "!pip install monai\n",
        "!pip install MedShapeNetCore"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "download the dataset"
      ],
      "metadata": {
        "id": "GuyPvbXkvk0q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python -m MedShapeNetCore download ASOCA"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y5BL5LQTvnCz",
        "outputId": "adbf6313-92aa-407a-f102-bcc08bcea446"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "downloading...\n",
            "[################################] 42842/42842 - 00:00:12\n",
            "download complete...\n",
            "file directory: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "import necessay packages"
      ],
      "metadata": {
        "id": "Z9QKBKFHv1lr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import sys\n",
        "import monai\n",
        "import torch\n",
        "import numpy as np\n",
        "from MedShapeNetCore.MedShapeNetCore import MyDict,MSNLoader,MSNVisualizer,MSNSaver,MSNTransformer\n",
        "pin_memory = torch.cuda.is_available()\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print('using:',device)\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "id": "qtu4zr59tHYd",
        "outputId": "18de5a56-776a-47f2-a839-7299764f8b93"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "using: cuda\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "load and prepare the shape data"
      ],
      "metadata": {
        "id": "35p131DDwzid"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "msn_loader=MSNLoader()\n",
        "ASOCA_DATA=msn_loader.load('ASOCA')\n",
        "shape_data=ASOCA_DATA['mask']\n",
        "shape_labels=ASOCA_DATA['labels']\n",
        "print(shape_data.shape)\n",
        "print(shape_labels)\n",
        "labels = torch.nn.functional.one_hot(torch.as_tensor(shape_labels).to(torch.int64)).float() #  one hot encoding\n",
        "print(labels.shape)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X7-tLTIew2Z8",
        "outputId": "3536e09c-0343-45f3-a171-d161aadd9274"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "current dataset: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz\n",
            "available keys in the dataset: ['mask', 'point', 'mesh', 'labels']\n",
            "(40, 256, 256, 256)\n",
            "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.\n",
            " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
            "torch.Size([40, 2])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "define the optimizer, loss function and a classification model based on DenseNet"
      ],
      "metadata": {
        "id": "mk8oZ5WqykWK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)\n",
        "loss_function = torch.nn.CrossEntropyLoss()\n",
        "optimizer = torch.optim.Adam(model.parameters(), 1e-4)"
      ],
      "metadata": {
        "id": "41ot9r4byrq-"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "train the model for 200 epochs"
      ],
      "metadata": {
        "id": "5fQhCF42zLF3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "max_epochs=200\n",
        "torch.cuda.empty_cache()\n",
        "import gc\n",
        "gc.collect()\n",
        "for epoch in range(max_epochs):\n",
        "    print(f\"epoch {epoch + 1}/{max_epochs}\")\n",
        "    model.train()\n",
        "    epoch_loss = 0\n",
        "    for i in range(len(shape_labels)):\n",
        "        torch.cuda.empty_cache()\n",
        "        inputs = torch.tensor(np.expand_dims(np.expand_dims(shape_data[i],axis=0),axis=0),dtype=torch.float32).to(device)\n",
        "        labels = torch.tensor(np.expand_dims(shape_labels[i],axis=0),dtype=torch.float32).type(torch.LongTensor).to(device)\n",
        "        torch.cuda.empty_cache()\n",
        "        optimizer.zero_grad()\n",
        "        outputs = model(inputs)\n",
        "        loss = 20*loss_function(outputs, labels)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        epoch_loss += loss.item()\n",
        "        print(f\"train_loss: {loss.item():.4f}\")"
      ],
      "metadata": {
        "id": "p92oRSjVzNpY"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}